import random, sys
import gmpy2
from gmpy2 import f_div, mpz, mpz_urandomb, is_prime, random_state, invert, powmod, add, mul, f_mod

from phe import paillier

rand = random_state(random.randrange(sys.maxsize))

def generate_random_number(n):
    """生成一个随机数，范围在 [1, n-1] 之间"""
    r = powmod(mpz(random.randint(1, n - 1)), n, n ** 2)
    #r = mpz(random.randint(1, n - 1))
    return r

def generate_prime(bits):
    """generate an b-bit prime integer"""
    while True:
        possible = mpz(2) ** (bits - 1) + mpz_urandomb(rand, bits - 1)
        if is_prime(possible):
            return possible


# 这里的bits主要是针对n来说，pq应该为bits//2
def keyGeneration(bits):
    """
        密钥生成
        :param length: 二进制位数，默认128
        :return: 公钥sk=n=pq,私钥sk=(λ, v),h用来解密阶段加密用
    """
    p = generate_prime(bits // 2)
    q = generate_prime(bits // 2)
    # print(f'p = {p}')
    # print(f'q = {q}')
    # step2：计算n=pq,λ=lcm(p-1,q-1)
    n = p * q
    λ = gmpy2.lcm(p - 1, q - 1)
    # step3：随机选择多个整数r，h=r^n mod n^2
    # TODO:这里先随机选择一个
    r = mpz(random.randint(1, n - 1))
    # r = 1
    h = powmod(r, n, n ** 2)  # (base ** exponent) % modulus，其中 base 是底数，exponent 是指数，modulus 是模数。
    # step4:引入秘密参数 v=λ^-1 mod n 即λ模n的逆元，λ * v % n = 1
    v = invert(λ, n)
    return n, λ, v, h


def encryption(m, n, h):
    """
        加密阶段
        :param m: 明文
        :param n: 公钥
        :param h: 密钥生成中的h，用来加密
        :return: 密文c
    """
    c = mul(add(mul(m, n), 1), h)  # 密文
    return c


def decryption(c, λ, n, v):
    """
    解密阶段
    :param c: 密文
    :param λ: 私钥λ
    :param n: 公钥
    :param v: 私钥v
    :return:
    """
    x = powmod(c, λ, n ** 2)
    L = f_div(x - 1, n)
    de_m = f_mod(mul(v, L), n)
    return de_m


def encryption_add(n, c1, c2):
    """Add one encrypted integer to another"""
    return powmod(c1 * c2, 1, n**2)


def encryption_add_const(n, m, c):
    """Add a constant to an encrypted integer"""
    return mul(m, add(mul(c, n), 1))


def encryption_mul_const(n, m, c):
    """Multiply an encrypted integer by a constant"""
    #encrypted_m = encryption(m, n, h)  # 将常量 m 加密
    #return encryption_add(n, c, encrypted_m)
    return powmod(m, c, n**2)


def rerandomize(ciphertext, n , h):
    return encryption_add(n , ciphertext, encryption(mpz(0), n,h))

if __name__ == '__main__':
    n, λ, v, h = keyGeneration(1024)
    print(f'n = {n}')
    print(f'λ = {λ}')
    print(f'h = {h}')
    print(f'v = {v}')
    print('-------------------------加解密----------------------')
    m = 30  # 明文
    c = encryption(m, n, h)  # 密文
    print(f'明文m = {m}')
    print(f'密文c = {c}')
    de_m = decryption(c, λ, n, v)
    print(f'解密明文de_m = {de_m}')

    print('-------------------------同态加,密文乘=明文加----------------------')
    m1, m2 = 10, 30
    c1, c2 = encryption(m1, n, h), encryption(m2, n, h)

    c1c2 = encryption_add(n, c1, c2)

    de_mm = decryption(c1c2, λ, n, v)
    print(f'解密明文 de_mm = {de_mm}')

    print('-------------------------密文×常量----------------------')
    const_c = 25
    m3 = 40
    c3 = encryption(m3, n, h)
    c3const_c = encryption_mul_const(n, c3, const_c)
    de_mm = decryption(c3const_c, λ, n, v)
    print(f'解密明文 de_mm = {de_mm}')

    print('-------------------------密文+常量----------------------')
    c3_add_const_c = encryption_add_const(n, c3, const_c)
    de_mm = decryption(c3_add_const_c, λ, n, v)
    print(f'解密明文 de_mm = {de_mm}')

    print('-------------------------密文*随机数----------------------')

    r1 = generate_random_number(n)  # 生成随机数
    r1 = h
    print(f"生成的随机数 r1: {r1}")
    print(f"生成随机数 r1类型：{type(r1)}")
    #r11 = encryption(r1,n,h)
    # 将 m1 加密
    c1 = encryption(mpz(10), n, h)
    print("Dec(c1) = ",decryption(c1,λ, n, v))
    # 乘以随机数
    res = decryption(r1*c1, λ, n, v)
    print(f"Dec(r1*m1) = ", res)